# -*- coding:utf-8 -*-

import tensorflow as tf

"""
学习率相关
"""


def get_learning_rate(global_step, lr_config):
    """
    学习率选择
    :param global_step:
    :param lr_config: 学习类配置
    :return:
    """
    # 固定学习率
    if lr_config.name == 'fixed':
        lr = tf.constant(lr_config.value, dtype=tf.float32)

    # 指数衰减学习率 learning x decay_rate ^ step/decay_step
    elif lr_config.name == 'exp_decay':
        lr = tf.train.exponential_decay(
            lr_config.value, global_step, lr_config.decay_step,
            lr_config.decay_rate, staircase=True)
    else:
        raise ValueError('不支持该类型学习率')
    return lr
